{
  "cells": [
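    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5eeje4O8fviH",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# Highway with SB3's DQN\n",
        "\n",
        "This notebook trains a DQN agent from Stable-Baselines3 (SB3) on the `highway-fast-v0` environment, then records and displays a few test episodes.\n",
        "\n",
        "## Warming up\n",
        "We start with a few useful installs and imports:"
      ]
    },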
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bzMSuJEOfviP",
        "pycharm": {
          "is_executing": false,
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Install environment and agent.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install highway-env\n",
        "# TODO: we use the bleeding edge version because the current stable version does not support the latest gym>=0.21 versions. Revert back to stable at the next SB3 release.\n",
        "%pip install git+https://github.com/DLR-RM/stable-baselines3\n",
        "\n",
        "# Environment\n",
        "import gymnasium as gym\n",
        "import highway_env\n",
        "\n",
        "# Agent\n",
        "from stable_baselines3 import DQN\n",
        "\n",
        "# Visualization utils\n",
        "%load_ext tensorboard\n",
        "import os\n",
        "import sys\n",
        "from tqdm.notebook import trange\n",
        "%pip install tensorboardx gym pyvirtualdisplay\n",
        "!apt-get install -y xvfb ffmpeg\n",
        "# Fetch the HighwayEnv repository for the video helpers in its scripts/ folder.\n",
        "# stderr is discarded, which also hides the error git prints when the clone already exists.\n",
        "!git clone https://github.com/Farama-Foundation/HighwayEnv.git 2> /dev/null\n",
        "# The clone lands in the current working directory, so resolve the scripts folder\n",
        "# from there instead of assuming Colab's /content.\n",
        "sys.path.insert(0, os.path.abspath('HighwayEnv/scripts'))\n",
        "from utils import record_videos, show_videos"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "id": "_wACJRDjqP-f",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "## Training\n",
        "Launch TensorBoard to visualize training. It reads the logs that the agent writes to the `highway_dqn` directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZSRTtNNzE5nL",
        "pycharm": {
          "name": "#%% \n"
        }
      },
      "outputs": [],
      "source": [
        "# Open the TensorBoard dashboard on the directory passed as `tensorboard_log` to the DQN model.\n",
        "%tensorboard --logdir \"highway_dqn\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y5TOvonYqP-g",
        "pycharm": {
          "name": "#%% \n"
        }
      },
      "outputs": [],
      "source": [
        "# Build the DQN agent; passing the environment id lets SB3 create the environment itself.\n",
        "model = DQN(\n",
        "    policy='MlpPolicy',\n",
        "    env='highway-fast-v0',\n",
        "    # Network\n",
        "    policy_kwargs={'net_arch': [256, 256]},\n",
        "    # Optimization\n",
        "    learning_rate=5e-4,\n",
        "    batch_size=32,\n",
        "    gamma=0.8,\n",
        "    train_freq=1,\n",
        "    gradient_steps=1,\n",
        "    target_update_interval=50,\n",
        "    # Replay buffer and exploration\n",
        "    buffer_size=15000,\n",
        "    learning_starts=200,\n",
        "    exploration_fraction=0.7,\n",
        "    # Logging\n",
        "    verbose=1,\n",
        "    tensorboard_log='highway_dqn/',\n",
        ")\n",
        "\n",
        "# Train for 20,000 environment steps.\n",
        "model.learn(total_timesteps=int(2e4))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n2Bu_Pqop0E7"
      },
      "source": [
        "## Testing\n",
        "\n",
        "Record and display a few test episodes played by the trained agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xOcOP7Of18T2",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Roll out the trained agent for a few episodes while recording videos.\n",
        "env = gym.make('highway-fast-v0', render_mode='rgb_array')\n",
        "env = record_videos(env)\n",
        "for episode in trange(3, desc='Test episodes'):\n",
        "    # env.step reports the end of an episode through two separate flags (the\n",
        "    # 3rd and 4th return values, bound here to `done` and `truncated`).\n",
        "    # Stop on either, otherwise the loop keeps stepping an episode that was truncated.\n",
        "    (obs, info), done, truncated = env.reset(), False, False\n",
        "    while not (done or truncated):\n",
        "        # Act greedily with respect to the learned policy\n",
        "        action, _ = model.predict(obs, deterministic=True)\n",
        "        obs, reward, done, truncated, info = env.step(int(action))\n",
        "env.close()\n",
        "show_videos()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "sb3_highway_dqn.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "metadata": {
          "collapsed": false
        },
        "source": []
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
